{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "0a078071-3cb1-4278-b375-5c45efa100ca",
   "metadata": {},
   "source": [
    "<center>\n",
    "<h1>基于pytorch + LSTM 的古诗生成</h1>\n",
    "</center>"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "967d209d-b78d-4244-acf5-116462793004",
   "metadata": {},
   "source": [
    "### 课程介绍: \n",
    "本课程使用pytorch框架, 完成NLP任务:古诗生成,使用的模型为 LSTM, 并训练了词向量, 支持随机古诗和藏头诗生成, 并且生成的古诗具有多变性。在课堂中会从0完成代码的编写，并分为多个文件，以便对比。"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d4ea3e70-08b6-4758-971b-4f2d2807b906",
   "metadata": {},
   "source": [
    "<br>"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "aecdbbb6-e8ef-4d71-8aec-b6b30e3bcd4b",
   "metadata": {},
   "source": [
    "### 导包:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "d162c71c-6dca-4c4a-830a-4e3e73a9dd4f",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import numpy as np\n",
    "import pickle\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "from gensim.models.word2vec import Word2Vec\n",
    "from torch.utils.data import Dataset, DataLoader"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5a2a9016-fe98-44ed-988c-ae8a42e4f741",
   "metadata": {},
   "source": [
    "<br>"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "53f2a818-a853-4fe8-a35d-ffa346d09202",
   "metadata": {},
   "source": [
    "### 生成切分文件:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "65f00d6c-e537-478c-88b7-b273a121ff85",
   "metadata": {},
   "outputs": [],
   "source": [
    "def split_text(file=\"poetry_7.txt\", train_num=6000):\n",
    "    all_data = open(file, \"r\", encoding=\"utf-8\").read()\n",
    "    with open(\"split_7.txt\", \"w\", encoding=\"utf-8\") as f:\n",
    "        split_data = \" \".join(all_data)\n",
    "        f.write(split_data)\n",
    "    return split_data[:train_num * 64]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "48a29d1e-73cf-4430-b4e1-78b516ffcd23",
   "metadata": {},
   "source": [
    "<br>"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "201d6941-ef95-4dc2-a528-f8e3c2020df0",
   "metadata": {},
   "source": [
    "### 训练词向量:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "dc0814c1-bfcd-42fe-9a67-95e5bc1a1e3b",
   "metadata": {},
   "outputs": [],
   "source": [
    "def train_vec(split_file=\"split_7.txt\", org_file=\"poetry_7.txt\", train_num=6000):\n",
    "    param_file = \"word_vec.pkl\"\n",
    "    org_data = open(org_file, \"r\", encoding=\"utf-8\").read().split(\"\\n\")[:train_num]\n",
    "    if os.path.exists(split_file):\n",
    "        all_data_split = open(split_file, \"r\", encoding=\"utf-8\").read().split(\"\\n\")[:train_num]\n",
    "    else:\n",
    "        all_data_split = split_text().split(\"\\n\")[:train_num]\n",
    "\n",
    "    if os.path.exists(param_file):\n",
    "        return org_data, pickle.load(open(param_file, \"rb\"))\n",
    "\n",
    "    models = Word2Vec(all_data_split, vector_size=128, workers=7, min_count=1)\n",
    "    pickle.dump([models.syn1neg, models.wv.key_to_index, models.wv.index_to_key], open(param_file, \"wb\"))\n",
    "    return org_data, (models.syn1neg, models.wv.key_to_index, models.wv.index_to_key)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e2eea4ec-a0e2-4857-aea9-5078e9aaeed2",
   "metadata": {},
   "source": [
    "<br>"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7e1bbc2b-57e1-4cef-a9e7-92eced1879e7",
   "metadata": {},
   "source": [
    "### 构建数据集:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "91c761d3-8d79-4a9f-b208-4685e8593c2f",
   "metadata": {},
   "outputs": [],
   "source": [
    "class Poetry_Dataset(Dataset):\n",
    "    def __init__(self, w1, word_2_index, all_data):\n",
    "        self.w1 = w1\n",
    "        self.word_2_index = word_2_index\n",
    "        self.all_data = all_data\n",
    "\n",
    "    def __getitem__(self, index):\n",
    "        a_poetry = self.all_data[index]\n",
    "\n",
    "        a_poetry_index = [self.word_2_index[i] for i in a_poetry]\n",
    "        xs = a_poetry_index[:-1]\n",
    "        ys = a_poetry_index[1:]\n",
    "        xs_embedding = self.w1[xs]\n",
    "\n",
    "        return xs_embedding, np.array(ys).astype(np.int64)\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.all_data)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7b53b25d-7c27-46e9-9ed5-eeece1868313",
   "metadata": {},
   "source": [
    "<br>"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4873af03-36bb-44a5-9f24-7a16cb1be51c",
   "metadata": {},
   "source": [
    "### 模型构建:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "d32e517d-646b-431d-9eb3-a74edc06bd6b",
   "metadata": {},
   "outputs": [],
   "source": [
    "class Poetry_Model_lstm(nn.Module):\n",
    "    def __init__(self, hidden_num, word_size, embedding_num):\n",
    "        super().__init__()\n",
    "\n",
    "        self.device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
    "        self.hidden_num = hidden_num\n",
    "\n",
    "        self.lstm = nn.LSTM(input_size=embedding_num, hidden_size=hidden_num, batch_first=True, num_layers=2,\n",
    "                            bidirectional=False)\n",
    "        self.dropout = nn.Dropout(0.3)\n",
    "        self.flatten = nn.Flatten(0, 1)\n",
    "        self.linear = nn.Linear(hidden_num, word_size)\n",
    "        self.cross_entropy = nn.CrossEntropyLoss()\n",
    "\n",
    "    def forward(self, xs_embedding, h_0=None, c_0=None):\n",
    "        if h_0 == None or c_0 == None:\n",
    "            h_0 = torch.tensor(np.zeros((2, xs_embedding.shape[0], self.hidden_num), dtype=np.float32))\n",
    "            c_0 = torch.tensor(np.zeros((2, xs_embedding.shape[0], self.hidden_num), dtype=np.float32))\n",
    "        h_0 = h_0.to(self.device)\n",
    "        c_0 = c_0.to(self.device)\n",
    "        xs_embedding = xs_embedding.to(self.device)\n",
    "        hidden, (h_0, c_0) = self.lstm(xs_embedding, (h_0, c_0))\n",
    "        hidden_drop = self.dropout(hidden)\n",
    "        hidden_flatten = self.flatten(hidden_drop)\n",
    "        pre = self.linear(hidden_flatten)\n",
    "\n",
    "        return pre, (h_0, c_0)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "70194fcf-7012-467a-b8bb-1cb1d92ef28b",
   "metadata": {},
   "source": [
    "<br>"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "676c499f-551e-4a5a-9464-644f2e05f143",
   "metadata": {},
   "source": [
    "### 自动生成古诗:\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "70bd6aa6-6e97-4f73-8629-abbf13184f30",
   "metadata": {},
   "outputs": [],
   "source": [
    "def generate_poetry_auto():\n",
    "    result = \"\"\n",
    "    word_index = np.random.randint(0, word_size, 1)[0]\n",
    "\n",
    "    result += index_2_word[word_index]\n",
    "    h_0 = torch.tensor(np.zeros((2, 1, hidden_num), dtype=np.float32))\n",
    "    c_0 = torch.tensor(np.zeros((2, 1, hidden_num), dtype=np.float32))\n",
    "\n",
    "    for i in range(31):\n",
    "        word_embedding = torch.tensor(w1[word_index][None][None])\n",
    "        pre, (h_0, c_0) = model(word_embedding, h_0, c_0)\n",
    "        word_index = int(torch.argmax(pre))\n",
    "        result += index_2_word[word_index]\n",
    "\n",
    "    return result\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "33901082-0b12-4a85-a523-0b9745668be9",
   "metadata": {},
   "source": [
    "<br>"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "bd04be6a-9eaa-440b-9ea3-25a4f9a6eb48",
   "metadata": {},
   "source": [
    "### 藏头诗生成:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "75beb453-4ced-48cb-867d-9d3a862f99ff",
   "metadata": {},
   "outputs": [],
   "source": [
    "def generate_poetry_acrostic():\n",
    "    input_text = input(\"请输入四个汉字：\")[:4]\n",
    "    result = \"\"\n",
    "    punctuation_list = [\"，\", \"。\", \"，\", \"。\"]\n",
    "    for i in range(4):\n",
    "        result += input_text[i]\n",
    "        h_0 = torch.tensor(np.zeros((2, 1, hidden_num), dtype=np.float32))\n",
    "        c_0 = torch.tensor(np.zeros((2, 1, hidden_num), dtype=np.float32))\n",
    "        word = input_text[i]\n",
    "        for j in range(6):\n",
    "            word_index = word_2_index[word]\n",
    "            word_embedding = torch.tensor(w1[word_index][None][None])\n",
    "            pre , (h_0,c_0) = model(word_embedding,h_0,c_0)\n",
    "            word = word_2_index[int(torch.argmax(pre))]\n",
    "            result += word\n",
    "\n",
    "    return result"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "bc23dabc-7c12-453c-b3ce-f2d66f10b5bc",
   "metadata": {},
   "source": [
    "<br>"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "581610cc-ea20-4484-93cb-e01d3c4048f4",
   "metadata": {},
   "source": [
    "### 主函数: 定义参数, 模型, 优化器, 模型训练"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "8638fa48-fbc5-44fa-b127-900567d1d012",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "loss:8.173\n",
      "渺声。。。成。。，。堪。，。。，。成。。。。。。。。。。。。定。\n",
      "loss:7.006\n",
      "寒。，，，，，，。，。，，，，，。。。。。。。，，，，。，。。。\n",
      "loss:6.838\n",
      "急。，，。，。，。。。，，。。。。。。。，，，，，。，，。。，。\n",
      "loss:6.742\n",
      "已。，，，。。。。。。，。，。。，。，。。。。。。。。。，，。。\n",
      "loss:6.703\n",
      "跃山山。。，。，，，。。，。，。。，，。，。。。，。。。。，。。\n",
      "loss:6.656\n",
      "翻有三，，，。。。。。。。。。。。，，。，。，。，，。。。。。。\n",
      "loss:6.642\n",
      "魂山三路，，，，，，，。，，，。。。，。。。。。。。，。。。，，\n",
      "loss:6.634\n",
      "嬴门风，三，，，，，，，。。。。，。。。。。。。。。。。，。。。\n",
      "loss:6.563\n",
      "评有风海，，，，，，，，，，，，，，，，，，。。。。。。，。，。\n",
      "loss:6.460\n",
      "鲛门风传斗，，，，，，。。。。。。。。。。。。。。。。。。。。。\n",
      "loss:6.406\n",
      "现门三树斗，，，，，，，。。。，。。。。。。。。。。。。。。。。\n",
      "loss:6.242\n",
      "傅门三海生天，，，，，。。。。。。一。。。。。。，。。。。。。。\n",
      "loss:6.107\n",
      "读山下车林天，，，山来。一一一。一来来。。来一。一山。。。山。。\n",
      "loss:5.943\n",
      "履山风车性三宣，一无一。一风花。一山风。一山天，一风一。一不风。\n",
      "loss:5.827\n",
      "珊色风光三林，，一来烟。一山青。一山一。一来风。一山如。一山时。\n",
      "loss:5.760\n",
      "眷台高路香有，，一山一山一来天，一山如声不来新。一来风人一来天。\n",
      "loss:5.687\n",
      "芃八若光海峰，，一须一。，一青。一来一。一无花。一来一。一天微。\n",
      "loss:5.612\n",
      "泚人梅光姓三宽，一山一。一天心。一来一山一不天。一来一山。不来。\n",
      "loss:5.580\n",
      "叉得梅路三天空，一步一。一天仙。一山一声一无花。未山一年。不来。\n",
      "loss:5.539\n",
      "岳阙别光不千过，砌山一觉一水仙。一山一山一云花。一来一山一不花，\n",
      "loss:5.538\n",
      "栈台若光海天天，一山一阁一水新。一来一风一无。，一教烟花一不花，\n",
      "loss:5.451\n",
      "雁有风光斗无天，一山如花旧天场。一来一年一无。，一山一年一人香，\n",
      "loss:5.444\n",
      "身门风光斗氏天，筠载编贵。天深。一是遥时不回仙，一来一人一无香，\n",
      "loss:5.413\n",
      "邸色风辣千冬，，一贲编阁旧粉中。一来黄年一无来，一来一。不时中，\n",
      "loss:5.399\n",
      "庵山风光斗天天，一山一水。天香。一家一生一无青。一来一年无不。。\n",
      "loss:5.365\n",
      "活有风光不天过，东章不溪为天天。一来一头知无。。海山一教不人花，\n",
      "loss:5.368\n",
      "峦山何路六一筵，一然新溪一天衣。一里一山一无。，一来一声一无违，\n",
      "loss:5.345\n",
      "略台风路能有过，偶歌一，一天心。一山一语，无公。一风一人无人夜，\n",
      "loss:5.341\n",
      "而门何涎接峰天，一然春溪一天心。一山一山，无里。一山一年一公青。\n",
      "loss:5.321\n",
      "艾色三树三有见，一楼一溪一天公。一来一人无无花。一来一影一人花。\n",
      "loss:5.289\n",
      "势帆高月香自天，一山清阁一天公。一山一人一禅青，金知一年无无老。\n",
      "loss:5.278\n",
      "缃有风路千海，，振山台溪色相公。一山一人一禅。。一是一声无不有。\n",
      "loss:5.263\n",
      "罕榜高车自天斗，一然此入一诗端。一山一日无无家。一山一风一黄。。\n",
      "loss:5.265\n",
      "岁门春车生峰天，振怀编溪一天场。一山文山一不花。一来一来常一寒。\n",
      "loss:5.220\n",
      "缳来风车见岛天，一山清阁一天端。一风一年天云人，海来烟人一无花。\n",
      "loss:5.163\n",
      "晦蛇觉月三峰天，踏山清溪一天场。一来青年一紫人，一山烟影一禅夜。\n",
      "loss:5.161\n",
      "治峰喧车三生水，一然身贵早一场。一来一年一染重，未山如人，公花。\n",
      "loss:5.150\n",
      "运得三光六其宽，振山载阁一水篇。一来海日一水有，万来烟人无人花。\n",
      "loss:5.136\n",
      "耨色何车海天天，一然清贵早超端。一是一山一紫果，一来一风天春香。\n",
      "loss:5.124\n",
      "牢朝风路千天天，一山清花旧本妍。遥山一年一圣句，一山一人留春。。\n",
      "loss:5.073\n",
      "晕色山结三海有，一眼编阁色超行。一是一年天不线，一得一时逢时会。\n",
      "loss:5.070\n",
      "男阙风光施注远，一山清天早超场。一山不有一紫在，一山一人常等青。\n",
      "loss:5.101\n",
      "郦色梅光三天宽，一山清花早超公。百是一语天染会，一来遥山一丈花。\n",
      "loss:5.138\n",
      "瑰来高光海天见，东锁富入一风留。一是盛年无圣老，一山一人不无违。\n",
      "loss:5.043\n",
      "世有波车碧天城，振名载花映笔妍。海阳一年一圣来，万山遥时一人家。\n",
      "loss:5.052\n",
      "柔台风光生峰远，一然无花一超衣。菜来不年常无句，一来一风初等花。\n",
      "loss:5.056\n",
      "课人帐结三若天，筠章一溪一天端。一是一年无圣客，一来遥人一无花。\n",
      "loss:5.008\n",
      "普色高空生名远，一山新山旧水端。他是一藤知梦化，一来遥人一等红。\n",
      "loss:4.975\n",
      "敉信梅路斗有垧，一名黄声一染衣。海风一年常不酒，一山遥山一春花。\n",
      "loss:4.942\n",
      "跎色梅车斗林菲，一山清溪映水微。遥是一人天筋句，一山应山一等铨。\n",
      "loss:4.909\n",
      "练帆若光生峰天，一然无气早超纹。一山纵年迷染果，金来遥人一等寒。\n",
      "loss:4.865\n",
      "增今风气海天宽，振落无天晚战场。一到当藤资野处，一山一花一尧。。\n",
      "loss:4.929\n",
      "蕤色三月暗丝菲，一然呵溪玉楼来。海来如人天染酒，一耳一时一禅花。\n",
      "loss:4.835\n",
      "卯迷何树斗芳宣，一山一阁一节心。澄知揽回判回籍，争今一年一云肩。\n",
      "loss:4.786\n",
      "绽朝三芬不峰家，一然新溪一迟端。一是纸德回染绿，一山谁劚一朦横。\n",
      "loss:4.782\n",
      "婴迷梅月三峰远，一仔烟花晚染场。一里安随秋百处，海山风鲤一尧花。\n",
      "loss:4.734\n",
      "浆今风路无天宽，匹山如天兴印衣。白是一层知无治，一鳞风山尚仔肩。\n",
      "loss:6.515\n",
      "圳来灞柳暗上舵，飞飞一天拥更。。。鸡耳高高身洲高，女无人未堂身，\n",
      "loss:5.755\n",
      "冻萧风光自天天，，入玉天。。。。。淼风初初，，。。马一。。。。生\n",
      "loss:5.269\n",
      "仲箐风功香生天，天然隐堆一。。。。阁日斜知。，。。来。一。。。。\n",
      "loss:5.039\n",
      "询蛇三道不峰界，精歌编风旧节。。青日年年。，，小来一劚。。。。。\n",
      "loss:4.918\n",
      "火榜御气六无履，一门清。一窥。。一里一人无圣，，海山烟人一。墙。\n",
      "loss:4.847\n",
      "金人风气六生之，东来无花色超妍。一是一人常圣，，更来风时一仔青。\n",
      "loss:4.855\n",
      "槔来作月千八远，一寿清分一天。。。日纵子无圣，，酒风一此一舜。。\n",
      "loss:4.810\n",
      "不事梅贩不峰远，一圃清花咸超微。海知盛子回野暗，未到仙此一仔。。\n",
      "loss:4.815\n",
      "臭来梅月不峰远，送龙身贵早超妍。青阳一竹双野，，愿夸放垒跨心林。\n"
     ]
    },
    {
     "ename": "KeyboardInterrupt",
     "evalue": "",
     "output_type": "error",
     "traceback": [
      "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[1;31mKeyboardInterrupt\u001b[0m                         Traceback (most recent call last)",
      "\u001b[1;32m<ipython-input-10-d0768a04d27e>\u001b[0m in \u001b[0;36m<module>\u001b[1;34m\u001b[0m\n\u001b[0;32m     18\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m     19\u001b[0m     \u001b[1;32mfor\u001b[0m \u001b[0me\u001b[0m \u001b[1;32min\u001b[0m \u001b[0mrange\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mepochs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 20\u001b[1;33m         \u001b[1;32mfor\u001b[0m \u001b[0mbatch_index\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m(\u001b[0m\u001b[0mbatch_x_embedding\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mbatch_y_index\u001b[0m\u001b[1;33m)\u001b[0m \u001b[1;32min\u001b[0m \u001b[0menumerate\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mdataloader\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m     21\u001b[0m             \u001b[0mmodel\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mtrain\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m     22\u001b[0m             \u001b[0mbatch_x_embedding\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mbatch_x_embedding\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mto\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mmodel\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mdevice\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32me:\\app\\python39\\lib\\site-packages\\torch\\utils\\data\\dataloader.py\u001b[0m in \u001b[0;36m__next__\u001b[1;34m(self)\u001b[0m\n\u001b[0;32m    515\u001b[0m             \u001b[1;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_sampler_iter\u001b[0m \u001b[1;32mis\u001b[0m \u001b[1;32mNone\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    516\u001b[0m                 \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_reset\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 517\u001b[1;33m             \u001b[0mdata\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_next_data\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m    518\u001b[0m             \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_num_yielded\u001b[0m \u001b[1;33m+=\u001b[0m \u001b[1;36m1\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    519\u001b[0m             \u001b[1;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_dataset_kind\u001b[0m \u001b[1;33m==\u001b[0m \u001b[0m_DatasetKind\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mIterable\u001b[0m \u001b[1;32mand\u001b[0m\u001b[0;31m \u001b[0m\u001b[0;31m\\\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32me:\\app\\python39\\lib\\site-packages\\torch\\utils\\data\\dataloader.py\u001b[0m in \u001b[0;36m_next_data\u001b[1;34m(self)\u001b[0m\n\u001b[0;32m    555\u001b[0m     \u001b[1;32mdef\u001b[0m \u001b[0m_next_data\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    556\u001b[0m         \u001b[0mindex\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_next_index\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m  \u001b[1;31m# may raise StopIteration\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 557\u001b[1;33m         \u001b[0mdata\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_dataset_fetcher\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mfetch\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mindex\u001b[0m\u001b[1;33m)\u001b[0m  \u001b[1;31m# may raise StopIteration\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m    558\u001b[0m         \u001b[1;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_pin_memory\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    559\u001b[0m             \u001b[0mdata\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0m_utils\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mpin_memory\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mpin_memory\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mdata\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32me:\\app\\python39\\lib\\site-packages\\torch\\utils\\data\\_utils\\fetch.py\u001b[0m in \u001b[0;36mfetch\u001b[1;34m(self, possibly_batched_index)\u001b[0m\n\u001b[0;32m     42\u001b[0m     \u001b[1;32mdef\u001b[0m \u001b[0mfetch\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mpossibly_batched_index\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m     43\u001b[0m         \u001b[1;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mauto_collation\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 44\u001b[1;33m             \u001b[0mdata\u001b[0m \u001b[1;33m=\u001b[0m \u001b[1;33m[\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mdataset\u001b[0m\u001b[1;33m[\u001b[0m\u001b[0midx\u001b[0m\u001b[1;33m]\u001b[0m \u001b[1;32mfor\u001b[0m \u001b[0midx\u001b[0m \u001b[1;32min\u001b[0m \u001b[0mpossibly_batched_index\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m     45\u001b[0m         \u001b[1;32melse\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m     46\u001b[0m             \u001b[0mdata\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mdataset\u001b[0m\u001b[1;33m[\u001b[0m\u001b[0mpossibly_batched_index\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32me:\\app\\python39\\lib\\site-packages\\torch\\utils\\data\\_utils\\fetch.py\u001b[0m in \u001b[0;36m<listcomp>\u001b[1;34m(.0)\u001b[0m\n\u001b[0;32m     42\u001b[0m     \u001b[1;32mdef\u001b[0m \u001b[0mfetch\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mpossibly_batched_index\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m     43\u001b[0m         \u001b[1;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mauto_collation\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 44\u001b[1;33m             \u001b[0mdata\u001b[0m \u001b[1;33m=\u001b[0m \u001b[1;33m[\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mdataset\u001b[0m\u001b[1;33m[\u001b[0m\u001b[0midx\u001b[0m\u001b[1;33m]\u001b[0m \u001b[1;32mfor\u001b[0m \u001b[0midx\u001b[0m \u001b[1;32min\u001b[0m \u001b[0mpossibly_batched_index\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m     45\u001b[0m         \u001b[1;32melse\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m     46\u001b[0m             \u001b[0mdata\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mdataset\u001b[0m\u001b[1;33m[\u001b[0m\u001b[0mpossibly_batched_index\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32m<ipython-input-6-035360fdf124>\u001b[0m in \u001b[0;36m__getitem__\u001b[1;34m(self, index)\u001b[0m\n\u001b[0;32m     11\u001b[0m         \u001b[0mxs\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0ma_poetry_index\u001b[0m\u001b[1;33m[\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m-\u001b[0m\u001b[1;36m1\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m     12\u001b[0m         \u001b[0mys\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0ma_poetry_index\u001b[0m\u001b[1;33m[\u001b[0m\u001b[1;36m1\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 13\u001b[1;33m         \u001b[0mxs_embedding\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mw1\u001b[0m\u001b[1;33m[\u001b[0m\u001b[0mxs\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m     14\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m     15\u001b[0m         \u001b[1;32mreturn\u001b[0m \u001b[0mxs_embedding\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mnp\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0marray\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mys\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mastype\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mnp\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mint64\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;31mKeyboardInterrupt\u001b[0m: "
     ]
    }
   ],
   "source": [
    "if __name__ == \"__main__\":\n",
    "\n",
    "    all_data, (w1, word_2_index, index_2_word) = train_vec(train_num=300)\n",
    "\n",
    "    batch_size = 32\n",
    "    epochs = 1000\n",
    "    lr = 0.01\n",
    "    hidden_num = 128\n",
    "    word_size, embedding_num = w1.shape\n",
    "\n",
    "    dataset = Poetry_Dataset(w1, word_2_index, all_data)\n",
    "    dataloader = DataLoader(dataset, batch_size)\n",
    "\n",
    "    model = Poetry_Model_lstm(hidden_num, word_size, embedding_num)\n",
    "    model = model.to(model.device)\n",
    "\n",
    "    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)\n",
    "\n",
    "    for e in range(epochs):\n",
    "        for batch_index, (batch_x_embedding, batch_y_index) in enumerate(dataloader):\n",
    "            model.train()\n",
    "            batch_x_embedding = batch_x_embedding.to(model.device)\n",
    "            batch_y_index = batch_y_index.to(model.device)\n",
    "\n",
    "            pre, _ = model(batch_x_embedding)\n",
    "            loss = model.cross_entropy(pre, batch_y_index.reshape(-1))\n",
    "\n",
    "            loss.backward()  # 梯度反传 , 梯度累加, 但梯度并不更新, 梯度是由优化器更新的\n",
    "            optimizer.step()  # 使用优化器更新梯度\n",
    "            optimizer.zero_grad()  # 梯度清零\n",
    "\n",
    "            if batch_index % 100 == 0:\n",
    "                # model.eval()\n",
    "                print(f\"loss:{loss:.3f}\")\n",
    "                print(generate_poetry_auto())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d4aafe08-343c-4d44-8950-80ab93aee48e",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
